Add self-guide mode for CachedGISTEmbedLoss #3662
yjoonjang wants to merge 6 commits into huggingface:main
Conversation
I also thought about making a … Thank you!
Hello! This is very cool! Have you been able to run some training tests with this? I imagine it might work pretty nicely compared to e.g. MultipleNegativesRankingLoss.
Hi @tomaarsen, thank you for your comment! I ran some experiments comparing a few setups; you can find the details in my colab. While the performance gain is modest, the results look promising. I wonder why the test results aren't exactly the same for the two GIST experiments.
Thanks for running and sharing your results, very nice! I imagine a difference between your experiment 2 (self-guided) and experiment 3 (guided with mpnet-base) is that in experiment 2 the guide model is continuously being updated. There is a bit of a risk with the self-guiding that the model learns to consider a lot of samples as false negatives, so those get ignored. I'm glad to see that the self-guided mode saves a lot of training time though, that's very helpful. What was the margin of -1.0 based on? I don't fully remember what …
Ah, thanks for the clarification. I forgot that the guide model is being updated 🤣

**Answer to your question: how does the margin work?**

The filtering condition in the code is:

# absolute strategy
mask = guided_sim_mat > (guided_sim - margin)

where `guided_sim_mat` holds the guide's similarities for the in-batch candidates and `guided_sim` is the guide's positive-pair similarity; masked entries are treated as false negatives and excluded.

**Why a negative margin for self-guided?**

In the original GISTEmbed, the guide is a strong, frozen teacher. Its similarity scores are reliable, so we trust it to identify false negatives. In self-guided mode, the guide is the model itself, which is still being trained and initially unreliable. If we used the usual small positive margin, the noisy self-similarities would mask many valid in-batch negatives, so a negative margin relaxes the threshold: only candidates scoring clearly above the positive get filtered.

**Corrected experiment**

I realized I made a mistake in my previous experiment: I was using `margin=-1.0`, which makes the threshold so loose that effectively no filtering occurs. The performance gain I observed in the previous experiment was likely due to the difference in default temperature between the two losses. I re-ran the experiment with `margin=-0.1`.
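To make the margin mechanics concrete, here is a small self-contained sketch of the absolute strategy (function and variable names are mine for illustration, not the actual loss internals):

```python
# Illustrative sketch of the "absolute" margin strategy: a candidate is
# flagged as a false negative when the guide scores it above the
# positive-pair similarity minus the margin.

def false_negative_mask(guided_sim_mat, guided_pos_sim, margin):
    return [
        [sim > (pos - margin) for sim in row]
        for row, pos in zip(guided_sim_mat, guided_pos_sim)
    ]

# One query: positive-pair similarity 0.80, three in-batch candidates.
sims = [[0.85, 0.75, 0.40]]
pos = [0.80]

# Positive margin 0.1 -> threshold 0.70: filters aggressively.
print(false_negative_mask(sims, pos, 0.1))   # [[True, True, False]]

# Negative margin -0.1 -> threshold 0.90: only clear outliers get filtered.
print(false_negative_mask(sims, pos, -0.1))  # [[False, False, False]]
```

With a noisy self-guide, the negative margin keeps almost all in-batch negatives in the loss, which matches the "no filtering occurred at margin=-1.0" observation above.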
The self-guided mode matches the original GISTEmbed accuracy while being ~1.8x faster, nearly matching MNRL's training speed. This speedup is expected, since self-guided mode eliminates the second forward pass through the guide model.

However, I should note that the accuracy improvement over MNRL is likely still driven by the temperature difference rather than the self-guided filtering itself. The results are very similar to the previous experiment with `margin=-1.0` (where no filtering occurred), which suggests that at this stage the student model is too weak to produce meaningful similarity estimates for effective false-negative filtering. Since the model is still early in training, its similarity scores are essentially noise, so even with `margin=-0.1` the filtering is not yet contributing a useful signal.

We expect self-guided filtering to become more effective when starting from a stronger base model or later in training, where the model's similarity estimates are more reliable. A more comprehensive evaluation at larger scale with a stronger student would be needed to isolate the true benefit of self-guided filtering from the temperature effect.

**Precedent from recent work**

The idea of using the model's own similarity scores for in-batch false-negative filtering with a tolerance margin is well-established in recent embedding papers:
All three confirm that when using the model's own scores for filtering, a positive tolerance (a relaxed threshold) is necessary to avoid over-filtering. Our self-guided mode with a negative `margin` follows the same principle.
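On the temperature point discussed above: a tiny self-contained illustration (with made-up similarity values) of how the scale, i.e. the inverse temperature, changes the softmax that these InfoNCE-style losses optimize:

```python
# Made-up numbers showing why the scale alone can shift results: a larger
# scale sharpens the softmax over in-batch candidates, so the loss
# concentrates on the hardest negatives.
import math

def softmax(logits):
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

sims = [0.8, 0.7, 0.3]  # positive first, then two in-batch negatives

for scale in (20.0, 100.0):
    p_pos = softmax([scale * s for s in sims])[0]
    print(f"scale={scale}: p(positive)={p_pos:.4f}")
# The lower scale leaves ~0.88 on the positive; the higher one saturates
# near 1.0, yielding much smaller gradients once the positive clearly wins.
```

So two losses with identical filtering but different default scales can land at noticeably different accuracies, independent of the self-guide mechanism.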
For context, we had a related discussion in #3665 about the broader direction of online vs. offline distillation/guidance. Linking it here so the context is easy to find for anyone following this PR.
I'll run some tests with the new implementation.
Thanks! My guess is that it doesn't work great with a base model (bert-base-uncased, mpnet-base, ModernBERT-base), but does work nicely with a model that's already an embedding model.
I've run some experiments.

**Experiments**

**Setup**

**Results (NanoBEIR-full)**

wandb: [link]
I couldn't find a solid trend here, but it looks …

**Warmup Mechanism Verification**

To verify the warmup step, I've added some logs (just for this experiment) to the callback. With warmup (`warmup_ratio=0.2`):

With no warmup (`warmup_ratio=0`):

In conclusion, the warmup mechanism works as intended.
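For readers following along, the warmup gating being verified above can be sketched like this (a hedged sketch of one way a warmup ratio can gate self-guided filtering; the PR's actual implementation may differ in details):

```python
# Keep self-guided filtering off for the first `warmup_ratio` fraction of
# training steps, while the model's own similarity scores are still
# unreliable, then switch it on for the rest of training.

def filtering_active(step, total_steps, warmup_ratio):
    return step >= int(total_steps * warmup_ratio)

total = 100
first_active = next(s for s in range(total) if filtering_active(s, total, 0.2))
print(first_active)  # 20 -> filtering starts after 20% of training
print(filtering_active(0, total, 0.0))  # True -> no warmup, active immediately
```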
Should I run some extra experiments (e.g. …)?
I think that would be useful. GIST works pretty nicely, and this is similar, so I imagine there's some settings for which it works.
Happy to share some good news.
Hi @tomaarsen, could you please take a look at the results when you have a moment?
Thanks for the detailed benchmarks. It's nice to see that there are more gains when the …
Following your recommendation, I ran a margin sweep on `Alibaba-NLP/gte-multilingual-base`.

**Setup**

**Training script**

#!/bin/bash
GPUS="${GPUS:-4,5,6,7}"
NUM_GPUS=$(echo "$GPUS" | tr ',' '\n' | wc -l)
MODEL_NAME="Alibaba-NLP/gte-multilingual-base"
SHORT_MODEL_NAME="gte-multilingual-base"
mkdir -p $SHORT_MODEL_NAME-logs
# 1. No self-guide (baseline CachedMNRL)
echo "=== [1/9] No self-guide (baseline) ==="
CUDA_VISIBLE_DEVICES=$GPUS torchrun \
  --nproc_per_node=$NUM_GPUS --master_port 29502 \
  train.py \
  --model_name $MODEL_NAME \
  --query_prefix "" \
  --doc_prefix "" \
  --per_device_train_batch_size 512 \
  --mini_batch_size 64 \
  > $SHORT_MODEL_NAME-logs/no_sg.log 2>&1

# 2. Self-guide
RUN_IDX=2
for MARGIN in -0.1 -0.2 -0.25 -0.3; do
  for WARMUP in 0.2 0; do
    if [ "$WARMUP" == "0.2" ]; then
      WARMUP_LABEL="w/ warmup"
    else
      WARMUP_LABEL="w/o warmup"
    fi
    echo "=== [$RUN_IDX/9] margin=${MARGIN}, ${WARMUP_LABEL} ==="
    CUDA_VISIBLE_DEVICES=$GPUS torchrun \
      --nproc_per_node=$NUM_GPUS --master_port 29502 \
      train.py \
      --model_name $MODEL_NAME \
      --self_guide \
      --self_guide_warmup_ratio $WARMUP \
      --self_guide_margin=$MARGIN \
      --self_guide_margin_strategy absolute \
      --query_prefix "" \
      --doc_prefix "" \
      --per_device_train_batch_size 512 \
      --mini_batch_size 64 \
      > $SHORT_MODEL_NAME-logs/sg_warmup_${WARMUP}_margin_${MARGIN}_absolute.log 2>&1
    RUN_IDX=$((RUN_IDX + 1))
  done
done
echo "=== All 9 runs completed ==="

**Training Python file**

import argparse
import logging
import os

from datasets import load_dataset
from sentence_transformers import (
    SentenceTransformer,
    SentenceTransformerTrainer,
    SentenceTransformerTrainingArguments,
)
from sentence_transformers.evaluation import NanoBEIREvaluator
from sentence_transformers.losses import CachedMultipleNegativesRankingLoss
from sentence_transformers.training_args import BatchSamplers

logging.basicConfig(format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO)
logger = logging.getLogger(__name__)


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name", type=str, default="microsoft/mpnet-base")
    # Self-guide arguments
    parser.add_argument("--self_guide", action="store_true", help="Enable self-guided false-negative filtering")
    parser.add_argument("--self_guide_margin", type=float, default=0.0, help="Margin for false-negative detection threshold")
    parser.add_argument("--self_guide_margin_strategy", type=str, default="absolute", choices=["absolute", "relative"], help="Strategy for applying the margin")
    parser.add_argument("--self_guide_warmup_ratio", type=float, default=0.0, help="Fraction of training steps to disable filtering (warmup)")
    # Hardness weighting arguments (can be combined with self-guide)
    parser.add_argument("--hardness_mode", type=str, default="none", choices=["none", "in_batch_negatives", "hard_negatives", "all_negatives"])
    parser.add_argument("--hardness_strength", type=float, default=0.0)
    # Prefix arguments (e.g., for e5 models: --query_prefix "query: " --doc_prefix "passage: ")
    parser.add_argument("--query_prefix", type=str, default="", help="Prefix prepended to anchor/query texts")
    parser.add_argument("--doc_prefix", type=str, default="", help="Prefix prepended to positive/negative texts")
    # Training arguments
    parser.add_argument("--per_device_train_batch_size", type=int, default=256)
    parser.add_argument("--mini_batch_size", type=int, default=32)
    parser.add_argument("--num_train_epochs", type=int, default=1)
    parser.add_argument("--learning_rate", type=float, default=2e-5)
    parser.add_argument("--output_dir", type=str, default="models/self_guide")
    parser.add_argument("--wandb_project", type=str, default="Self_Guide")
    return parser.parse_args()


def main():
    args = parse_args()
    hardness_mode = None if args.hardness_mode == "none" else args.hardness_mode
    short_model_name = args.model_name.split("/")[-1]

    # Build run name
    parts = [short_model_name]
    if args.self_guide:
        sg_name = "sg"
        if args.self_guide_margin != 0.0:
            sg_name += f"_m{args.self_guide_margin}_{args.self_guide_margin_strategy}"
        sg_name += f"_w{args.self_guide_warmup_ratio}"
        parts.append(sg_name)
    if hardness_mode is not None:
        parts.append(f"{args.hardness_mode}_s{args.hardness_strength}")
    if not parts:
        parts.append("baseline")
    run_name = "_".join(parts)
    logger.info(
        f"Run: {run_name} | self_guide={args.self_guide}, margin={args.self_guide_margin}, "
        f"margin_strategy={args.self_guide_margin_strategy}, warmup_ratio={args.self_guide_warmup_ratio}, "
        f"hardness_mode={hardness_mode}, hardness_strength={args.hardness_strength}"
    )

    # 1. Load model
    model = SentenceTransformer(args.model_name, trust_remote_code=True)
    model.max_seq_length = 512

    # 2. Load dataset
    dataset = load_dataset("tomaarsen/natural-questions-hard-negatives", "triplet-5", split="train")
    dataset = dataset.rename_columns({"query": "anchor", "answer": "positive"})

    # Apply query/doc prefixes if specified
    if args.query_prefix or args.doc_prefix:
        def add_prefixes(example):
            for key in example:
                if key == "anchor":
                    example[key] = args.query_prefix + example[key]
                else:
                    example[key] = args.doc_prefix + example[key]
            return example

        dataset = dataset.map(add_prefixes)
        logger.info(f"Applied prefixes: query='{args.query_prefix}', doc='{args.doc_prefix}'")
    logger.info(f"Dataset size: {len(dataset)}")

    # 3. Define loss
    loss = CachedMultipleNegativesRankingLoss(
        model,
        mini_batch_size=args.mini_batch_size,
        gather_across_devices=False,
        hardness_mode=hardness_mode,
        hardness_strength=args.hardness_strength,
        self_guide=args.self_guide,
        self_guide_margin=args.self_guide_margin,
        self_guide_margin_strategy=args.self_guide_margin_strategy,
        self_guide_warmup_ratio=args.self_guide_warmup_ratio,
    )

    # 4. Evaluator (subset of NanoBEIR for mid-training eval)
    evaluator_kwargs = {}
    if args.query_prefix:
        evaluator_kwargs["query_prompts"] = args.query_prefix
    if args.doc_prefix:
        evaluator_kwargs["corpus_prompts"] = args.doc_prefix
    dev_evaluator = NanoBEIREvaluator(
        dataset_names=["msmarco", "nq", "nfcorpus", "quoraretrieval"],
        show_progress_bar=True,
        batch_size=64,
        **evaluator_kwargs,
    )

    # 5. Training arguments
    os.environ["WANDB_PROJECT"] = args.wandb_project
    training_args = SentenceTransformerTrainingArguments(
        output_dir=f"{args.output_dir}/{run_name}",
        num_train_epochs=args.num_train_epochs,
        per_device_train_batch_size=args.per_device_train_batch_size,
        learning_rate=args.learning_rate,
        warmup_ratio=0.1,
        bf16=True,
        eval_on_start=True,
        eval_strategy="steps",
        eval_steps=0.2,
        logging_steps=10,
        logging_first_step=True,
        save_strategy="no",
        report_to="wandb",
        run_name=run_name,
        ddp_find_unused_parameters=True,
    )

    # 6. Train
    trainer = SentenceTransformerTrainer(
        model=model,
        args=training_args,
        train_dataset=dataset,
        loss=loss,
        evaluator=dev_evaluator,
    )
    trainer.train()

    # 7. Full NanoBEIR evaluation (all 13 subsets)
    logger.info("Running full NanoBEIR evaluation on all subsets...")
    full_evaluator = NanoBEIREvaluator(
        show_progress_bar=True,
        batch_size=64,
        **evaluator_kwargs,
    )
    full_evaluator(model)

    # 8. Save
    model.save_pretrained(f"{args.output_dir}/{run_name}/final")
    logger.info(f"Done: {run_name}")


if __name__ == "__main__":
    main()

**Results**
The best is … I'd be happy if these results help.
Hi @tomaarsen, apologies for tagging you a bunch of times.
Thank you! I'm surprised at how small the gaps are. Very interesting. I think perhaps I'll have to do more research on this. However, I want to try and push #3554 first as there's a lot of demand there.
Summary
Hello! This PR optimizes `CachedGISTEmbedLoss` for self-guided training scenarios where the student model serves as its own guide.

**Motivation**
Recent embedding papers (e.g., Qwen3-Embedding, Diffusion-Pretrained Dense and Contextual Embeddings (a.k.a. pplx-embed)) have shown that using the model's own similarity scores with a margin (e.g., `margin=-1.0`) provides effective self-guidance without requiring a separate guide model.

[Qwen3-Embedding]

[pplx-embed]
This approach filters in-batch false negatives with a condition like `student_score > positive_score - margin` (e.g., with a margin of 0.1).

However, when using the model as its own guide, the previous implementation still required passing the same instance twice and performed two forward passes: one for student embeddings and one for guide embeddings. This is computationally wasteful since both would produce identical results.
Changes
- The `guide` parameter is now optional: when `guide=None` (default), the model uses itself as the guide (self-guided mode)
- Self-guided mode reuses `reps.detach()` as `guide_reps` instead of calling the guide model again
- In self-guided mode, `must_retokenize` is always `False` (same tokenizer)

**Code Changes**
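As a rough dependency-free sketch of the core idea (illustrative names, not the actual diff): when no guide is given, reuse the student's own representations, detached so no gradient flows through the guide signal, instead of running a second forward pass.

```python
# FakeTensor stands in for torch.Tensor so the sketch runs without torch;
# only .detach() semantics matter here.
class FakeTensor:
    def __init__(self, data, requires_grad=True):
        self.data = data
        self.requires_grad = requires_grad

    def detach(self):
        # A detached view shares data but receives no gradients.
        return FakeTensor(self.data, requires_grad=False)


def guide_representations(reps, guide=None, guide_encode=None):
    if guide is None:
        # Self-guided mode: no second forward pass, just detach.
        return [r.detach() for r in reps]
    # Explicit guide: run the guide's forward pass as before.
    return [guide_encode(r) for r in reps]


reps = [FakeTensor([0.1, 0.2])]
guide_reps = guide_representations(reps)
print(guide_reps[0].requires_grad)  # False
```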
Usage Example
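The snippet for this section appears to have been lost in rendering; the following is a sketch of what self-guided usage presumably looks like with this PR. It is not runnable here (it downloads a model), and the `margin`/`margin_strategy` kwargs are assumed to follow the existing `GISTEmbedLoss` API.

```python
from sentence_transformers import SentenceTransformer
from sentence_transformers.losses import CachedGISTEmbedLoss

model = SentenceTransformer("microsoft/mpnet-base")

# Self-guided mode: no separate guide model, the student guides itself.
loss = CachedGISTEmbedLoss(
    model,
    guide=None,                 # default with this PR
    margin=-0.1,                # relaxed threshold for the noisy self-guide
    margin_strategy="absolute",
)

# Backwards compatible: passing an explicit guide behaves as before.
# guide_model = SentenceTransformer("all-MiniLM-L6-v2")
# loss = CachedGISTEmbedLoss(model, guide=guide_model)
```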
Benefits
- Self-guided training with no separate guide model and no second forward pass (`guide=None`)
- Fully backwards compatible with an explicit guide (`guide=model`)